package com.zhj.wapper.java;


import com.zhj.util.ClassUtils;

import org.springframework.util.Assert;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import javax.tools.Diagnostic;
import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.StandardJavaFileManager;
import javax.tools.ToolProvider;


/**
 * Compiles one Java source string in memory and runs methods of the resulting class.
 *
 * <p>The source is compiled in the constructor. Every class file produced by the
 * compiler (the top-level class as well as nested and anonymous classes) is kept in
 * memory and served to the JVM through {@link StringClassLoader}.
 *
 * <p>{@link #runMethod(String, Object...)} temporarily replaces {@link System#out},
 * which is a JVM-wide setting, so concurrent runs will interleave their captured output.
 */
public class JavaCompile {

    private static final String MAIN = "main";

    /**
     * Compiled bytecode, keyed by binary class name (for example {@code a.b.Main$Inner}).
     * Filled while compiling inside the constructor and only read afterwards.
     */
    private final Map<String, ByteJavaFileObject> compiledClasses = new HashMap<>();

    /**
     * The system Java compiler; {@code null} when the runtime does not ship one.
     */
    private final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();

    /**
     * Collects the diagnostics emitted during compilation.
     */
    private final DiagnosticCollector<JavaFileObject> diagnosticsCollector = new DiagnosticCollector<>();

    /**
     * Fully qualified name of the class to compile, as reported by {@code ClassUtils}.
     */
    private final String fullClassName;

    /**
     * The source code to compile.
     */
    private final String sourceCode;

    /**
     * Class loader that serves the classes compiled from the source string.
     */
    public StringClassLoader classLoader;

    /**
     * Instance created by the most recent invocation of a non-static method.
     */
    private Object targetObj;

    /**
     * Time spent compiling, in milliseconds.
     */
    private long compilerTaskTime;

    /**
     * Time spent in the most recent {@link #runMethod(String, Object...)} call, in milliseconds.
     */
    private long runTaskTime;

    /**
     * Text written to {@code System.out} during the most recent run.
     */
    private String runResult;

    /**
     * Compiles the given source code.
     *
     * @param sourceCode the Java source to compile
     * @throws IllegalArgumentException if compilation fails; the message carries the diagnostics
     * @throws IllegalStateException    if no system Java compiler is available
     */
    public JavaCompile(String sourceCode) {
        this.sourceCode = sourceCode;
        this.fullClassName = ClassUtils.getFullClassName(sourceCode);
        this.classLoader = AccessController.doPrivileged((PrivilegedAction<StringClassLoader>) StringClassLoader::new);
        boolean compiled = this.compile();
        if (!compiled) {
            // Same exception type Assert.isTrue would raise, but the message is only built on failure.
            throw new IllegalArgumentException(String.format("代码编译失败：%s", this.getCompileMessage()));
        }
    }

    /**
     * Runs the compilation task and records how long it took.
     *
     * @return {@code true} if the source compiled without errors
     */
    private boolean compile() {
        Assert.noNullElements(new Object[] {sourceCode, fullClassName, classLoader}, "初始化构建参数不能为空，请调用init()方法进行初始化");
        if (compiler == null) {
            throw new IllegalStateException("No system Java compiler is available; a JDK is required");
        }
        long startTime = System.currentTimeMillis();
        // The forwarding manager closes the wrapped standard file manager when it is closed.
        try (StringJavaFileManage fileManager =
                     new StringJavaFileManage(compiler.getStandardFileManager(diagnosticsCollector, null, null))) {
            StringJavaFileObject source = new StringJavaFileObject(fullClassName, sourceCode);
            JavaCompiler.CompilationTask task = compiler.getTask(null,
                                                                 fileManager,
                                                                 diagnosticsCollector,
                                                                 null,
                                                                 null,
                                                                 Collections.singletonList(source));
            return task.call();
        } catch (IOException e) {
            throw new IllegalStateException("Failed to close the compiler file manager", e);
        } finally {
            // Measured after task.call() so the figure includes the compilation itself.
            compilerTaskTime = System.currentTimeMillis() - startTime;
        }
    }

    /**
     * Runs the {@code main} method of the compiled class.
     *
     * @param args the arguments handed to {@code main}
     */
    public void runMainMethod(String... args) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException,
                                                     ClassNotFoundException {
        runMethod(MAIN, new Object[] { args });
    }

    /**
     * Runs the named method of the compiled class and captures what it prints to {@code System.out}.
     *
     * <p>Static methods are invoked without creating an instance. For instance methods a new
     * object is created through the no-argument constructor. The method is looked up by the
     * exact runtime types of {@code params} first; if that fails (or an argument is
     * {@code null}) a single declared method whose parameters accept the arguments, including
     * primitive parameters receiving their wrapper type, is used instead.
     *
     * <p>The captured output is stored for {@link #getRunResult()} and echoed to the original
     * {@code System.out}, even when the invoked method throws.
     *
     * @param methodName name of the declared method to invoke
     * @param params     arguments for the method; {@code null} is treated as no arguments
     * @return the value returned by the invoked method
     * @throws NoSuchMethodException     if no unique matching method (or no-arg constructor) exists
     * @throws InvocationTargetException if the invoked method or constructor throws
     * @throws IllegalAccessException    if the method or constructor is not accessible
     * @throws ClassNotFoundException    if the compiled class cannot be loaded
     * @throws IllegalStateException     if the class cannot be instantiated (e.g. it is abstract)
     */
    public <T> Object runMethod(String methodName, Object... params) throws InvocationTargetException,
                                                                            IllegalAccessException,
                                                                            NoSuchMethodException,
                                                                            ClassNotFoundException {
        Object[] args = params == null ? new Object[0] : params;
        PrintStream originalOut = System.out;
        ByteArrayOutputStream captured = new ByteArrayOutputStream();
        PrintStream capturingOut = newUtf8PrintStream(captured);
        long startTime = System.currentTimeMillis();
        System.setOut(capturingOut);
        try {
            Class<?> targetClass = classLoader.findClass(fullClassName);
            Method method = resolveMethod(targetClass, methodName, args);
            Object receiver = null;
            if (!Modifier.isStatic(method.getModifiers())) {
                targetObj = targetClass.getDeclaredConstructor().newInstance();
                receiver = targetObj;
            }
            return method.invoke(receiver, args);
        } catch (InstantiationException e) {
            throw new IllegalStateException("Cannot instantiate " + fullClassName, e);
        } finally {
            runTaskTime = System.currentTimeMillis() - startTime;
            // Collect the output here so it is kept even when the invoked method failed.
            capturingOut.flush();
            runResult = new String(captured.toByteArray(), StandardCharsets.UTF_8);
            // Restore the original stream, then echo what was captured.
            System.setOut(originalOut);
            System.out.println(runResult);
        }
    }

    /**
     * Intended to run a class method from a jar; currently a placeholder that does nothing.
     */
    public void compileForJar() {

    }

    /**
     * Creates a print stream that encodes as UTF-8, matching how the captured bytes are decoded.
     */
    private static PrintStream newUtf8PrintStream(OutputStream sink) {
        try {
            return new PrintStream(sink, true, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 is not supported by this JVM", e);
        }
    }

    /**
     * Finds the declared method to invoke: exact runtime-type match first, then a unique
     * declared method whose parameters accept the given arguments.
     */
    private static Method resolveMethod(Class<?> type, String methodName, Object[] args) throws NoSuchMethodException {
        NoSuchMethodException exactFailure = null;
        Class<?>[] exactTypes = determineParamsType(args);
        if (exactTypes != null) {
            try {
                return type.getDeclaredMethod(methodName, exactTypes);
            } catch (NoSuchMethodException e) {
                // Remembered so the original message is reported if the lenient search fails too.
                exactFailure = e;
            }
        }
        Method match = null;
        for (Method candidate : type.getDeclaredMethods()) {
            // Compiler-generated bridge/synthetic methods would cause spurious ambiguity.
            if (candidate.isBridge() || candidate.isSynthetic()) {
                continue;
            }
            if (!candidate.getName().equals(methodName) || !accepts(candidate.getParameterTypes(), args)) {
                continue;
            }
            if (match != null) {
                throw new NoSuchMethodException("Ambiguous call: more than one method named " + methodName
                                                + " in " + type.getName() + " accepts the given arguments");
            }
            match = candidate;
        }
        if (match == null) {
            throw exactFailure != null
                    ? exactFailure
                    : new NoSuchMethodException(type.getName() + "." + methodName + " with " + args.length + " argument(s)");
        }
        return match;
    }

    /**
     * Determines the runtime types of the arguments.
     *
     * @param params the arguments, never {@code null}
     * @return the runtime class of each argument, or {@code null} if any argument is {@code null}
     */
    private static Class<?>[] determineParamsType(Object[] params) {
        Class<?>[] classes = new Class<?>[params.length];
        for (int i = 0; i < params.length; i++) {
            if (params[i] == null) {
                return null;
            }
            classes[i] = params[i].getClass();
        }
        return classes;
    }

    /**
     * Tells whether a method with the given parameter types can be invoked with the arguments.
     */
    private static boolean accepts(Class<?>[] parameterTypes, Object[] args) {
        if (parameterTypes.length != args.length) {
            return false;
        }
        for (int i = 0; i < args.length; i++) {
            Object arg = args[i];
            if (arg == null) {
                if (parameterTypes[i].isPrimitive()) {
                    return false;
                }
            } else if (!wrap(parameterTypes[i]).isInstance(arg)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Maps a primitive type to its wrapper class; other types are returned unchanged.
     */
    private static Class<?> wrap(Class<?> type) {
        if (!type.isPrimitive()) {
            return type;
        }
        if (type == int.class) {
            return Integer.class;
        }
        if (type == long.class) {
            return Long.class;
        }
        if (type == boolean.class) {
            return Boolean.class;
        }
        if (type == double.class) {
            return Double.class;
        }
        if (type == float.class) {
            return Float.class;
        }
        if (type == char.class) {
            return Character.class;
        }
        if (type == byte.class) {
            return Byte.class;
        }
        if (type == short.class) {
            return Short.class;
        }
        return Void.class;
    }

    /**
     * Returns the compiler diagnostics, one per line (terminated by {@code \r\n}).
     *
     * @return the diagnostics text; empty when the compiler reported nothing
     */
    public String getCompileMessage() {
        StringBuilder builder = new StringBuilder();
        List<Diagnostic<? extends JavaFileObject>> diagnostics = diagnosticsCollector.getDiagnostics();
        for (Diagnostic<? extends JavaFileObject> diagnostic : diagnostics) {
            builder.append(diagnostic.toString()).append("\r\n");
        }
        return builder.toString();
    }

    /**
     * Returns the text printed to {@code System.out} during the most recent run.
     *
     * @return the captured output, or {@code null} if nothing has been run yet
     */
    public String getRunResult() {
        return this.runResult;
    }

    /**
     * Returns how long the compilation took.
     *
     * @return compilation time in milliseconds
     */
    public long getCompilerTaskTime() {
        return this.compilerTaskTime;
    }

    /**
     * Returns how long the most recent run took.
     *
     * @return run time in milliseconds; {@code 0} if nothing has been run yet
     */
    public long getRunTaskTime() {
        return this.runTaskTime;
    }

    /**
     * A source file object whose content is an in-memory string.
     */
    private static class StringJavaFileObject extends SimpleJavaFileObject {
        /**
         * The source code.
         */
        private final String sourceCode;

        /**
         * Creates a source file object whose URI is derived from the class name.
         */
        public StringJavaFileObject(String className, String sourceCode) {
            super(URI.create(String.format("string:///%s", className.replace('.', '/') + Kind.SOURCE.extension)), Kind.SOURCE);
            this.sourceCode = sourceCode;
        }

        // Called by the compiler to read the source text.
        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
            return sourceCode;
        }
    }

    /**
     * A file object that keeps the compiler's output in a byte array.
     */
    private static class ByteJavaFileObject extends SimpleJavaFileObject {
        // Holds the bytes written by the compiler.
        private final ByteArrayOutputStream outPutStream = new ByteArrayOutputStream();

        public ByteJavaFileObject(String className, Kind kind) {
            // The extension follows the actual kind (".class" for class files).
            super(URI.create("string:///" + className.replace('.', '/') + kind.extension), kind);
        }

        // Called by the compiler to write its output; any earlier content is discarded.
        @Override
        public OutputStream openOutputStream() {
            outPutStream.reset();
            return outPutStream;
        }

        // Used by the class loader when defining the class.
        public byte[] getCompiledBytes() {
            return outPutStream.toByteArray();
        }
    }


    /**
     * File manager that redirects the compiler's output into memory.
     */
    private class StringJavaFileManage extends ForwardingJavaFileManager<JavaFileManager> {

        public StringJavaFileManage(JavaFileManager javaFileManager) {
            super(javaFileManager);
        }

        @Override
        public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind, FileObject sibling) throws IOException {
            ByteJavaFileObject output = new ByteJavaFileObject(className, kind);
            if (kind == JavaFileObject.Kind.CLASS) {
                // One entry per generated class, so nested classes do not overwrite each other.
                compiledClasses.put(className, output);
            }
            return output;
        }
    }

    /**
     * Class loader that defines classes from the bytecode compiled by the enclosing instance.
     * Every class generated from the source, including nested and anonymous classes, is
     * defined by this loader.
     */
    public class StringClassLoader extends ClassLoader {

        /**
         * Finds a class by name: a class already defined by this loader is reused, otherwise
         * the compiled bytecode is defined, otherwise the system class loader is consulted.
         *
         * @param name binary name of the class
         * @return the resolved class
         * @throws ClassNotFoundException if the class is neither compiled here nor loadable
         */
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            // Defining the same class twice in one loader raises LinkageError.
            Class<?> alreadyDefined = findLoadedClass(name);
            if (alreadyDefined != null) {
                return alreadyDefined;
            }
            ByteJavaFileObject compiled = compiledClasses.get(name);
            if (compiled != null) {
                byte[] compiledBytes = compiled.getCompiledBytes();
                return defineClass(name, compiledBytes, 0, compiledBytes.length);
            }
            try {
                // Fall back to the system class loader for classes that were not compiled here.
                return ClassLoader.getSystemClassLoader().loadClass(name);
            } catch (ClassNotFoundException e) {
                return super.findClass(name);
            }
        }
    }

}
